package com.xiaojie.sharding.sphere.shardingalgorithm;

import com.google.common.collect.Range;
import org.apache.shardingsphere.api.sharding.standard.PreciseShardingAlgorithm;
import org.apache.shardingsphere.api.sharding.standard.PreciseShardingValue;
import org.apache.shardingsphere.api.sharding.standard.RangeShardingAlgorithm;
import org.apache.shardingsphere.api.sharding.standard.RangeShardingValue;
import org.springframework.stereotype.Component;

import java.util.ArrayList;
import java.util.Collection;

/**
 * Custom table sharding algorithm that routes a row to the table whose name ends with
 * {@code "_" + (|value| % tableCount)}.
 *
 * <p>Notes carried over from the original author: the range sharding algorithm is optional and is
 * used for BETWEEN; the class must implement RangeShardingAlgorithm and provide a no-argument
 * constructor. By default ShardingSphere only supports {@code =} and {@code BETWEEN ... AND ...}
 * conditions; {@code >}, {@code <}, {@code >=} and {@code <=} only work when a custom
 * RangeShardingAlgorithm is supplied. PreciseShardingAlgorithm&lt;String&gt; must be implemented
 * as well.
 *
 * @author: yan
 * @date: 2022.03.12
 */
@Component
public class MyTableShardingAlgorithm implements PreciseShardingAlgorithm<String>, RangeShardingAlgorithm<Long> {

    /**
     * Selects every table that may hold a row whose sharding value lies in the queried range.
     *
     * <p>Returning a superset of the strictly required tables is always safe for routing, so the
     * method falls back to "all tables" whenever the range is open-ended or spans more values than
     * there are tables. Both endpoints are treated as inclusive, which can only widen the result.
     *
     * @param availableTargetNames candidate table names, expected to end with {@code _<index>}
     * @param shardingValue        the range condition extracted from the query
     * @return the tables to route to; empty only when no candidate matches
     */
    @Override
    public Collection<String> doSharding(Collection<String> availableTargetNames, RangeShardingValue<Long> shardingValue) {
        int tableCount = availableTargetNames.size();
        if (tableCount == 0) {
            return new ArrayList<>();
        }

        Range<Long> valueRange = shardingValue.getValueRange();
        // ">", ">=", "<", "<=" produce a range with a missing bound: any table may be hit.
        if (!valueRange.hasLowerBound() || !valueRange.hasUpperBound()) {
            return new ArrayList<>(availableTargetNames);
        }

        long lower = valueRange.lowerEndpoint();
        long upper = valueRange.upperEndpoint();
        // Guava guarantees lower <= upper, so a negative difference means the subtraction overflowed,
        // i.e. the range is enormous.
        long span = upper - lower;
        if (span < 0 || span >= tableCount) {
            return new ArrayList<>(availableTargetNames);
        }

        // Narrow range: at most tableCount values to inspect. Iterating by offset avoids
        // overflowing the loop variable when upper == Long.MAX_VALUE.
        Collection<String> matched = new ArrayList<>(tableCount);
        for (long offset = 0; offset <= span; offset++) {
            String suffix = "_" + segmentOf(lower + offset, tableCount);
            for (String each : availableTargetNames) {
                if (each.endsWith(suffix) && !matched.contains(each)) {
                    matched.add(each);
                }
            }
        }
        return matched;
    }

    /**
     * Selects the single table for an exact ({@code =} / {@code IN}) sharding value.
     *
     * @param availableTargetNames candidate table names, expected to end with {@code _<index>}
     * @param shardingValue        the sharding column value taken from the query condition; its
     *                             string form must be parseable as a {@code long}
     * @return the table whose name ends with {@code "_" + (|value| % tableCount)}
     * @throws NumberFormatException if the sharding value is not a valid {@code long}
     * @throws RuntimeException      if no candidate table matches
     */
    @Override
    public String doSharding(Collection<String> availableTargetNames, PreciseShardingValue<String> shardingValue) {
        int tableCount = availableTargetNames.size();
        if (tableCount > 0) {
            // The suffix does not depend on the candidate, so compute it once up front.
            long value = Long.parseLong(String.valueOf(shardingValue.getValue()));
            String suffix = "_" + segmentOf(value, tableCount);
            for (String each : availableTargetNames) {
                if (each.endsWith(suffix)) {
                    return each;
                }
            }
        }
        throw new RuntimeException(shardingValue + "没有匹配到表");
    }

    /**
     * Maps a sharding value to its table index in {@code [0, tableCount)}.
     *
     * <p>Equivalent to {@code Math.abs(value) % tableCount} for every value except
     * {@code Long.MIN_VALUE}, where {@code Math.abs} would stay negative; taking the remainder
     * first keeps the result non-negative for that value too.
     */
    private static long segmentOf(long value, int tableCount) {
        return Math.abs(value % tableCount);
    }
}